package tools

import (
	"fmt"
	"reflect"
	"sync/atomic"
	"unsafe"
)

// AtomicInt is an integer value of type T that is read and modified through
// sync/atomic operations (see Get, Set, Increase, Decrease).
type AtomicInt[T int32 | uint32 | int64 | uint64] struct {
	// data is the stored value; it is only meant to be accessed through the
	// atomic helpers on this type.
	// NOTE: keep data as the first field. Per the sync/atomic docs, on 32-bit
	// platforms only the first word of an allocated struct is guaranteed to
	// be 64-bit aligned, which 64-bit atomic operations require.
	data T
	// kind caches the reflect.Kind of T; reflect.Invalid (the zero value)
	// means it has not been computed yet (see generateKind).
	kind reflect.Kind
}

// NewAtomicInt returns a new AtomicInt initialised to data.
//
// The kind field is filled in eagerly, derived from the type parameter T
// rather than from the stored value. Any code path that populates kind
// lazily only does so while it is still reflect.Invalid, so for values built
// here it is never written again. Concurrent first use of such a value
// therefore cannot race on that field.
func NewAtomicInt[T int32 | uint32 | int64 | uint64](data T) *AtomicInt[T] {
	return &AtomicInt[T]{
		data: data,
		kind: reflect.TypeOf((*T)(nil)).Elem().Kind(),
	}
}

// Get atomically loads and returns the current value.
//
// The matching sync/atomic load is chosen with a type switch on the zero
// value of the type parameter T, not on the cached kind field. As a result,
// Get never reads a.data non-atomically and never writes a.kind, so
// concurrent callers stay free of data races, even on a zero-value
// AtomicInt.
func (a *AtomicInt[T]) Get() T {
	ptr := unsafe.Pointer(&a.data)
	var zero T
	switch any(zero).(type) {
	case int32:
		// T's constraint has no ~ terms, so T is exactly int32 here and the
		// pointer conversion is an identity reinterpretation.
		return T(atomic.LoadInt32((*int32)(ptr)))
	case uint32:
		return T(atomic.LoadUint32((*uint32)(ptr)))
	case int64:
		return T(atomic.LoadInt64((*int64)(ptr)))
	case uint64:
		return T(atomic.LoadUint64((*uint64)(ptr)))
	default:
		// Unreachable: T's constraint admits only the four types above.
		panic(fmt.Sprintf("not supported type %T", zero))
	}
}

// Set atomically replaces the current value with newData.
//
// A single atomic store is sufficient: the original Load + CompareAndSwap
// retry loop only ever stored newData unconditionally. The store is chosen
// with a type switch on the zero value of T instead of the cached kind
// field, so Set does not read or write any shared state non-atomically.
func (a *AtomicInt[T]) Set(newData T) {
	ptr := unsafe.Pointer(&a.data)
	var zero T
	switch any(zero).(type) {
	case int32:
		// T is exactly int32 in this arm, so the cast is an identity.
		atomic.StoreInt32((*int32)(ptr), int32(newData))
	case uint32:
		atomic.StoreUint32((*uint32)(ptr), uint32(newData))
	case int64:
		atomic.StoreInt64((*int64)(ptr), int64(newData))
	case uint64:
		atomic.StoreUint64((*uint64)(ptr), uint64(newData))
	default:
		// Unreachable: T's constraint admits only the four types above.
		panic(fmt.Sprintf("not supported type %T", zero))
	}
}

// Increase atomically adds delta to the stored value; the result wraps
// around on overflow like ordinary Go integer arithmetic.
func (a *AtomicInt[T]) Increase(delta T) {
	const decrease = false
	a.compareAndSwap(delta, decrease)
}

// IncreaseOne atomically adds 1 to the stored value.
func (a *AtomicInt[T]) IncreaseOne() {
	var one T = 1
	a.Increase(one)
}

// Decrease atomically subtracts delta from the stored value; the result
// wraps around on underflow like ordinary Go integer arithmetic.
func (a *AtomicInt[T]) Decrease(delta T) {
	const decrease = true
	a.compareAndSwap(delta, decrease)
}

// DecreaseOne atomically subtracts 1 from the stored value.
//
// It delegates to Decrease, mirroring how IncreaseOne delegates to
// Increase, so both convenience methods go through the public API.
func (a *AtomicInt[T]) DecreaseOne() {
	a.Decrease(1)
}

// generateKind lazily caches the reflect.Kind of the type parameter T in
// a.kind. It does nothing once a.kind is no longer reflect.Invalid.
//
// NOTE(review): the check-then-write on a.kind is not synchronized. Two
// goroutines making the first call on the same value concurrently race on
// a.kind, so do not rely on this for values shared before kind is set.
func (a *AtomicInt[T]) generateKind() {
	if a.kind != reflect.Invalid {
		return
	}
	// Derive the kind from the type parameter itself rather than from
	// a.data. reflect.TypeOf(a.data) would copy the stored value with a
	// plain, non-atomic read, which races with concurrent atomic writers.
	a.kind = reflect.TypeOf((*T)(nil)).Elem().Kind()
}

// compareAndSwap atomically adds delta to the stored value, or subtracts it
// when isDecrease is true. Despite the name, it uses a single sync/atomic
// Add operation rather than a hand-written Load + CompareAndSwap retry loop.
//
// Subtraction is performed as addition of -delta. Go integer arithmetic
// wraps modulo 2^n for both signed and unsigned types, so this yields
// exactly old-delta (including wrap-around) for every T.
//
// The Add function is chosen with a type switch on the zero value of T
// instead of the cached kind field, so no shared state is read or written
// non-atomically.
func (a *AtomicInt[T]) compareAndSwap(delta T, isDecrease bool) {
	if isDecrease {
		delta = -delta
	}
	ptr := unsafe.Pointer(&a.data)
	var zero T
	switch any(zero).(type) {
	case int32:
		// T is exactly int32 in this arm, so the cast is an identity.
		atomic.AddInt32((*int32)(ptr), int32(delta))
	case uint32:
		atomic.AddUint32((*uint32)(ptr), uint32(delta))
	case int64:
		atomic.AddInt64((*int64)(ptr), int64(delta))
	case uint64:
		atomic.AddUint64((*uint64)(ptr), uint64(delta))
	default:
		// Unreachable: T's constraint admits only the four types above.
		panic(fmt.Sprintf("not supported type %T", zero))
	}
}

//	// IncreaseOne 增加1
//	func(a *AtomicInt[T]) IncreaseOne() {
//		atomic.AddInt64(&a.data, 1)
//	}
//
//	// Increase 增加指定的值
//	func(a *AtomicInt[T]) Increase(v int64) {
//		atomic.AddInt64(&a.data, v)
//	}
//
//	// DecreaseOne 减去1
//	func(a *AtomicInt[T]) DecreaseOne() {
//		atomic.AddInt64(&a.data, -1)
//	}
//
//	// Decrease 减去指定的值
//	func(a *AtomicInt[T]) Decrease(v int64) {
//		atomic.AddInt64(&a.data, v*(-1))
//	}
//
//	func(a *AtomicInt[T]) compareAndSwap(delta T){
//
//}
